
import torch
import torch.nn as nn

# Character vocabulary for the toy sequence task ("hello" -> "ohlol").
idx2char = ['e', 'h', 'l', 'o']

# Input "hello" and target "ohlol" expressed as vocabulary indices.
x_data = [1, 0, 2, 2, 3]
y_data = [3, 1, 2, 3, 2]

# Identity-matrix rows act as the one-hot encoding table.
one_hot_lookup = [[int(row == col) for col in range(4)] for row in range(4)]

x_one_hot = [one_hot_lookup[idx] for idx in x_data]
print(x_one_hot)

input_size = len(idx2char)
seq_size = len(x_data)
batch_size = 1
hidden_size = len(idx2char)
num_layers = 1


# nn.GRU (batch_first=False) expects (seq_len, batch, input_size).
inputs = torch.tensor(x_one_hot, dtype=torch.float32).view(seq_size, batch_size, input_size)

# Class indices for CrossEntropyLoss.
labels = torch.tensor(y_data)

class NetGRU(nn.Module):
    """GRU that maps a one-hot character sequence to per-step logits.

    Args:
        input_size: size of each input feature vector (vocabulary size).
        hidden_size: GRU hidden dimension (doubles as the logit size).
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super(NetGRU, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first is left at its default (False): inputs are
        # (seq_len, batch, input_size).
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)

    def forward(self, inputs):
        """Run the GRU over `inputs` of shape (seq_len, batch, input_size).

        Returns:
            A tuple ``(logits, hidden)``: logits flattened to
            (seq_len * batch, hidden_size) so they can feed
            ``nn.CrossEntropyLoss`` directly, and the final hidden state of
            shape (num_layers, batch, hidden_size).
        """
        # Fix: derive the batch size from the input instead of reading a
        # module-level `batch_size` global, and allocate h0 on the same
        # device/dtype as the input so the model works on GPU inputs too.
        batch = inputs.size(1)
        h0 = torch.zeros(self.num_layers, batch, self.hidden_size,
                         dtype=inputs.dtype, device=inputs.device)
        output, hidden = self.gru(inputs, h0)

        return output.view(-1, self.hidden_size), hidden

# Build the model, then its optimizer and loss criterion.
net = NetGRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

net.train()
epochs = 100
print("*" * 80)
for epoch in range(epochs):
    # Standard step: clear grads, forward, loss, backward, update.
    optimizer.zero_grad()
    output, hidden = net(inputs)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()
    print("epoch %d, loss %f" % (epoch, loss.item()))



















